"""
RAG Reasoning - Provides reasoning capabilities using local RAG knowledge base.

This module replaces LLM provider calls with local RAG knowledge base queries.
"""

import json
import re
import textwrap
import logging
from typing import Any, Dict, List, Optional

from droidrun.tools.rag import query_rag

# Module-wide logger. It uses the fixed name "droidrun" rather than __name__,
# presumably so output is grouped with other droidrun modules (verify).
logger = logging.getLogger(name="droidrun")

class RAGReasoner:
    """RAG-based reasoner for Android automation.

    Answers come from the local knowledge base via ``query_rag`` instead of an
    LLM provider. ``temperature`` and ``max_tokens`` are accepted and stored
    only for compatibility; no method in this class uses them.
    """

    # Keys that describe the overall response rather than tool arguments. They
    # must never leak into the parameter dict built by the loose fallback
    # parser, e.g. from '"action": "tap"' appearing in the response text.
    _RESERVED_PARAM_KEYS = frozenset({"thought", "action", "parameters"})

    # Literal (lowercase) phrases that mark a question about the model's identity.
    _IDENTITY_PATTERNS = (
        r"你是谁",
        r"你叫什么",
        r"你的名字是什么",
        r"你是什么模型",
        r"你是什么大模型",
        r"你是哪个模型",
        r"你是什么语言模型",
        r"你是什么人工智能",
        r"你是什么ai",
        r"你是什么llm",
        r"你是什么assistant",
        r"你是什么助手",
        r"who are you",
        r"what are you",
        r"what model are you",
        r"what's your name",
        r"what is your name",
        r"which model are you",
        r"what language model are you",
        r"what ai are you",
        r"what assistant are you",
    )

    def __init__(
        self,
        temperature: float = 0.2,
        max_tokens: int = 2000
    ):
        """Initialize the RAG reasoner.

        Args:
            temperature: Temperature parameter (kept for compatibility; unused).
            max_tokens: Maximum tokens parameter (kept for compatibility; unused).
        """
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.token_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

    def get_token_usage_stats(self) -> Dict[str, int]:
        """Return the token-usage counters dict.

        The counters are initialized to zero and are not updated anywhere in
        this class. The same dict object is returned on every call.
        """
        return self.token_usage

    async def preprocess_ui(
        self,
        goal: str,
        history: List[Dict[str, Any]],
        current_ui_state: Optional[str] = None,
        screenshot_data: Optional[bytes] = None
    ) -> Dict[str, Any]:
        """Ask the RAG system for a narrative description of the current UI.

        Args:
            goal: The automation goal.
            history: Previous steps as dictionaries; the most recent step whose
                ``"type"`` is ``"action"`` is quoted in the query.
            current_ui_state: Current UI state (e.g. clickable elements as JSON).
            screenshot_data: Accepted for interface compatibility; not used.

        Returns:
            Whatever ``query_rag`` returns on success. NOTE(review): ``reason``
            treats that value as a string, so the declared Dict type may not
            hold; confirm against ``query_rag``. Returns ``{"elements": []}`` if
            the RAG call raises.
        """
        # Most recent action step, if any.
        last_action = next(
            (step for step in reversed(history) if step.get("type") == "action"),
            None,
        )
        # Use .get() so an action step without "content" cannot raise KeyError
        # here, outside the try block below.
        if last_action:
            last_action_text = last_action.get("content", "No previous action")
        else:
            last_action_text = "No previous action"

        query = f"""
        Goal: {goal}
        Last action: {last_action_text}
        Current UI state: {current_ui_state}
        
        Create a comprehensive narrative description of the current UI state.
        Focus on spatial relationships, interactive capabilities, and semantic meaning of elements.
        Describe the interface as if you are the eyes for someone who cannot see but needs to understand and interact with it.
        """

        try:
            return query_rag(query)
        except Exception as e:
            logger.error("Error in UI preprocessing: %s", e)
            return {"elements": []}

    async def reason(
        self,
        goal: str,
        history: List[Dict[str, Any]],
        available_tools: Optional[List[str]] = None,
        current_ui_state: Optional[str] = None,
        current_phone_state: Optional[str] = None,
        screenshot_data: Optional[bytes] = None,
        memories: Optional[List[Dict[str, str]]] = None
    ) -> Dict[str, Any]:
        """Generate the next reasoning step using the RAG knowledge base.

        Args:
            goal: The automation goal.
            history: Previous steps as dictionaries.
            available_tools: Optional list of available tool names.
            current_ui_state: Optional UI structure text included in the query.
            current_phone_state: Optional phone state text included in the query.
            screenshot_data: Accepted for interface compatibility; not used.
            memories: Optional list of memories (dicts with a ``"content"`` key).

        Returns:
            Dict with ``"thought"``, ``"action"`` and ``"parameters"``. If the
            RAG call or parsing raises, returns an ``"error"`` action whose
            thought carries the exception text.
        """
        logger.info("Token usage before RAG call: %s", self.get_token_usage_stats())

        query = self._create_query(
            goal,
            available_tools,
            history,
            memories,
            current_ui_state,
            current_phone_state
        )

        try:
            response = query_rag(query)
            result = self._parse_response(response)
            logger.info("Token usage after RAG call: %s", self.get_token_usage_stats())
            return result
        except Exception as e:
            logger.error("Error in RAG reasoning: %s", e)
            # Fallback response so the caller's loop can continue.
            return {
                "thought": f"RAG reasoning error: {e}",
                "action": "error",
                "parameters": {}
            }

    def _create_query(self,
        goal: str,
        available_tools: Optional[List[str]] = None,
        history: Optional[List[Dict[str, Any]]] = None,
        memories: Optional[List[Dict[str, str]]] = None,
        current_ui_state: Optional[str] = None,
        current_phone_state: Optional[str] = None,
        ) -> str:
        """Build the next-action query sent to the RAG system.

        Args:
            goal: The automation goal.
            available_tools: Optional list of available tool names.
            history: Optional list of previous steps.
            memories: Optional list of memories from the memory store.
            current_ui_state: Optional UI structure text.
            current_phone_state: Optional phone state text.

        Returns:
            The query string. It asks for a JSON answer with ``thought``,
            ``action`` and ``parameters`` fields.
        """
        query = f"""
        Goal: {goal}
        
        Based on the following information, what should be the next action to take on the Android device?
        Return your response in JSON format with the following fields:
        - thought: Your detailed reasoning about the current state and what to do next
        - action: The name of the tool to execute (use EXACT tool name without any parentheses)
        - parameters: A dictionary of parameters to pass to the tool
        
        Available tools:
        {self._add_tools_to_query(available_tools)}
        
        Memories:
        {self._add_memories_to_query(memories)}
        
        History:
        {self._add_history_to_query(history)}
        
        Phone state:
        {current_phone_state}
        
        UI structure:
        {current_ui_state}
        """

        return query

    def _add_tools_to_query(self, available_tools: Optional[List[str]]) -> str:
        """Format documentation lines for the available tools.

        Args:
            available_tools: Optional list of available tool names.

        Returns:
            One ``"- ...\\n"`` line per tool: its entry from ``tool_docs``, or a
            "(parameters unknown)" placeholder. Returns "" when there are no tools.
        """
        if not available_tools:
            # Checked before the import so an empty tool list never needs tool_docs.
            return ""

        from ..tool_docs import tool_docs

        return "".join(
            f"- {tool_docs[tool]}\n" if tool in tool_docs else f"- {tool} (parameters unknown)\n"
            for tool in available_tools
        )

    def _add_memories_to_query(self, memories: Optional[List[Dict[str, str]]]) -> str:
        """Format memories as a 1-based numbered list.

        Args:
            memories: Optional list of memory dicts; their ``"content"`` is shown.
                A memory without ``"content"`` is rendered as an empty entry.

        Returns:
            One ``"N. content\\n"`` line per memory, or "" when there are none.
        """
        if not memories:
            return ""

        return "".join(
            f"{i}. {memory.get('content', '')}\n" for i, memory in enumerate(memories, 1)
        )

    @staticmethod
    def _step_type(step: Dict[str, Any]) -> str:
        """Return the step's ``"type"`` uppercased, treating a missing or None type as ""."""
        return str(step.get("type") or "").upper()

    def _add_history_to_query(self, history: Optional[List[Dict[str, Any]]]) -> str:
        """Format the most recent history steps, newest first.

        GOAL steps are skipped (the goal is stated separately in the query),
        and at most the last 50 remaining steps are included.

        Args:
            history: Optional list of previous steps.

        Returns:
            ``"Step N - TYPE: content\\n"`` lines followed by a blank line, or ""
            when ``history`` is empty/None.
        """
        if not history:
            return ""

        filtered_history = [step for step in history if self._step_type(step) != "GOAL"]

        lines = [
            f"Step {step.get('step_number', 0)} - {self._step_type(step)}: {step.get('content', '')}\n"
            for step in reversed(filtered_history[-50:])
        ]
        return "".join(lines) + "\n"

    def _parse_response(self, response: str) -> Dict[str, Any]:
        """Parse the RAG response into a ``thought``/``action``/``parameters`` dict.

        First tries JSON: the contents of a fenced code block if present,
        otherwise the raw response. If that yields a JSON object, the missing
        fields are filled with defaults and the object is returned. Anything
        else (invalid JSON, or JSON that is not an object) goes to a regex-based
        fallback parser.

        Args:
            response: RAG response string.

        Returns:
            Dictionary with at least ``thought``, ``action`` and ``parameters``.
        """
        data = self._load_json_object(response)
        if data is None:
            return self._parse_loose_response(response)

        data.setdefault("thought", "No thought provided")
        data.setdefault("action", "no_action")
        data.setdefault("parameters", {})
        return data

    def _load_json_object(self, response: str) -> Optional[Dict[str, Any]]:
        """Return the JSON object carried by *response*, or None if it has none."""
        block_match = re.search(r'```(?:json)?\s*([\s\S]*?)\s*```', response)
        try:
            if block_match:
                try:
                    data = json.loads(block_match.group(1))
                    logger.info("Successfully parsed JSON from code block")
                except json.JSONDecodeError:
                    # Code block content is not valid JSON; try the raw response.
                    data = json.loads(response)
            else:
                data = json.loads(response)
        except json.JSONDecodeError:
            return None
        # Valid JSON that is not an object (list, string, number, ...) cannot
        # carry the expected fields, so treat it like unparsable text.
        return data if isinstance(data, dict) else None

    def _parse_loose_response(self, response: str) -> Dict[str, Any]:
        """Best-effort regex extraction of thought/action/parameters from non-JSON text."""
        thought_match = re.search(r'thought["\']\s*[:]?\s*["\'](.*?)["\'"\n]', response, re.DOTALL | re.IGNORECASE)
        action_match = re.search(r'action["\']\s*[:]?\s*["\'](.*?)["\'"\n,]', response, re.IGNORECASE)
        # A "parameters": {...} object. The pattern is non-greedy and stops at
        # the first "}", so nested objects are not supported.
        params_match = re.search(r'parameters["\']\s*[:]?\s*(\{.*?\})', response, re.DOTALL | re.IGNORECASE)

        thought = thought_match.group(1).strip() if thought_match else "Failed to parse thought"
        action = action_match.group(1).strip() if action_match else "no_action"

        if params_match:
            params_str = params_match.group(1)
            try:
                # Replace single quotes with double quotes to get valid JSON.
                params = json.loads(params_str.replace("'", "\""))
            except json.JSONDecodeError:
                logger.warning("Failed to parse parameters JSON")
                # Salvage key/value pairs from the parameters object itself.
                params = self._extract_param_pairs(params_str)
        else:
            # No parameters object: scan the whole response for key/value pairs.
            params = self._extract_param_pairs(response)

        result = {
            "thought": thought,
            "action": action,
            "parameters": params
        }
        # Mutates params in place, i.e. result["parameters"].
        self._apply_default_params(action, thought, params)
        return result

    def _extract_param_pairs(self, text: str) -> Dict[str, Any]:
        """Collect ``key: value`` pairs from *text*, converting decimal-digit values to int.

        Matches forms like ``"keycode": "3"``, ``keycode: 3`` or ``keycode: "HOME"``.
        Keys in ``_RESERVED_PARAM_KEYS`` are skipped; otherwise scanning a whole
        response would return e.g. ``"action": "tap"`` as a tool argument. When
        a key repeats, the last occurrence wins.
        """
        pairs = re.findall(r'["\']?(\w+)["\']?\s*[:]\s*["\']?([\w\.]+)["\']?', text)
        # isdecimal() (not isdigit()) accepts exactly the characters int() can
        # parse; isdigit() also accepts e.g. superscripts, which int() rejects.
        return {
            key: int(value) if value.isdecimal() else value
            for key, value in pairs
            if key not in self._RESERVED_PARAM_KEYS
        }

    def _apply_default_params(self, action: str, thought: str, params: Dict[str, Any]) -> None:
        """Fill in required arguments for common tools when they are missing (mutates *params*)."""
        lowered = thought.lower()
        if action == "press_key" and "keycode" not in params:
            # Guess the key from the thought ("主页" = home, "返回" = back).
            if "home" in lowered or "主页" in lowered:
                params["keycode"] = "HOME"
            elif "back" in lowered or "返回" in lowered:
                params["keycode"] = "BACK"
            else:
                params["keycode"] = "HOME"  # Default to HOME.
            logger.warning("Missing keycode parameter for press_key, defaulting to %s", params["keycode"])
        elif action == "tap" and "index" not in params:
            # Try to pull an index out of the thought ("索引"/"编号" = index/number).
            index_match = re.search(r'(?:index|索引|编号)[=: ]+(\d+)', lowered)
            params["index"] = int(index_match.group(1)) if index_match else 0
            params.setdefault("longpress", False)
            logger.warning("Missing index parameter for tap, defaulting to %s", params["index"])

    # Special method to handle identity questions
    def is_identity_question(self, query: str) -> bool:
        """Check whether *query* asks about the model's identity.

        Matching is case-insensitive (the query is lowercased and the patterns
        are lowercase) against ``_IDENTITY_PATTERNS``, in Chinese and English.

        Args:
            query: The query to check.

        Returns:
            True if any identity pattern occurs in the query, False otherwise.
        """
        lowered = query.lower()
        return any(re.search(pattern, lowered) for pattern in self._IDENTITY_PATTERNS)